import json
import cv2
import requests
from flask import request
from flask_restplus import Namespace, Resource, reqparse
from flask_login import login_required, current_user
from shapely.geometry import Point, Polygon

import base64
from celery.result import AsyncResult

from database import AnnotationModel, ImageModel, CategoryModel

from ..util import query_util
from ..util.shapely_util import annotations_containing_point, point_in_segmentation

from workers import celery
from workers.tasks import batch

import datetime
import logging

# Shared module logger. It uses the 'gunicorn.error' logger name, presumably
# so messages land alongside gunicorn's own error-log output.
logger = logging.getLogger('gunicorn.error')

api = Namespace('annotation', description='Annotation related operations')


def _with_annotation_fields(parser):
    """Register the JSON body fields common to annotation-creating parsers.

    Fields, in registration order: ``image_id`` (required int),
    ``category_id`` (int), ``isbbox`` (bool), ``metadata`` (dict),
    ``segmentation`` (list), ``keypoints`` (list) and ``color`` (untyped).

    Returns:
        The same ``parser``, so construction reads as a single expression.
    """
    parser.add_argument(
        'image_id', type=int, required=True, location='json')
    parser.add_argument('category_id', type=int, location='json')
    parser.add_argument('isbbox', type=bool, location='json')
    parser.add_argument('metadata', type=dict, location='json')
    parser.add_argument('segmentation', type=list, location='json')
    parser.add_argument('keypoints', type=list, location='json')
    parser.add_argument('color', location='json')
    return parser


# Plain annotation body: only the common fields.
create_annotation = _with_annotation_fields(reqparse.RequestParser())

# A bare image id, read from reqparse's default argument locations.
create_image_id = reqparse.RequestParser()
create_image_id.add_argument('image_id', type=int, required=True)

# Common fields plus optional integer ``x`` / ``y``.
create_auto = _with_annotation_fields(reqparse.RequestParser())
create_auto.add_argument('x', type=int, location='json')
create_auto.add_argument('y', type=int, location='json')

# Common fields plus a required float region (x, y, width, height).
create_roi = _with_annotation_fields(reqparse.RequestParser())
create_roi.add_argument('x', type=float, location='json', required=True)
create_roi.add_argument('y', type=float, location='json', required=True)
create_roi.add_argument('width', type=float, location='json', required=True)
create_roi.add_argument('height', type=float, location='json', required=True)

# Batch annotation over a dataset.
# NOTE: the 'treshold' spelling is part of the request contract; renaming it
# requires updating clients as well.
create_batch = reqparse.RequestParser()
create_batch.add_argument('dataset_id', type=int, location='json')
create_batch.add_argument('category_ids', type=list, location='json')
create_batch.add_argument('metadata', type=dict, location='json')
create_batch.add_argument('box_treshold', type=float, location='json')
create_batch.add_argument('text_treshold', type=float, location='json')

# PUT /<annotation_id>: only the category can be changed.
update_annotation = reqparse.RequestParser()
update_annotation.add_argument('category_id', type=int, location='json')

# Point query; reqparse's default argument locations.
annotation_contains_this_point = reqparse.RequestParser()
annotation_contains_this_point.add_argument('image_id', type=int, required=True)
annotation_contains_this_point.add_argument('x', type=float, required=True)
annotation_contains_this_point.add_argument('y', type=float, required=True)

# Copy-an-annotation request; reqparse's default argument locations.
annotation_by_segmentation = reqparse.RequestParser()
annotation_by_segmentation.add_argument('image_id', type=int, required=True)
annotation_by_segmentation.add_argument('annotation_id', type=int, required=True)
annotation_by_segmentation.add_argument('category_id', type=int, required=True)


@api.route('/yolo')
class AnnotationYolo(Resource):
    """Create box annotations from the remote YOLO detection service."""

    # Remote detection endpoint. Its JSON reply is read with the detections
    # under result['coco']['annotations'].
    PREDICT_URL = "http://125.220.157.228:29992/predict"

    @api.expect(create_annotation)
    def post(self):
        """Detect objects in an image and save each detection as a box.

        Every detection is stored with the request's ``category_id``; the
        service's own labels are not used.

        Returns:
            ({'annotations': [...]}, 200) on success;
            ({'message': ...}, 404) if the image does not exist;
            ({'message': ...}, 400) if the image has no file path, a detection
            is malformed (checked before anything is saved), or the model
            rejects a value;
            ({'message': ...}, 888) if the service is unreachable or answers
            with an HTTP error;
            ({'message': ...}, 502) if its reply lacks the expected structure.
        """
        args = create_annotation.parse_args()
        image_id = args.get('image_id')
        category_id = args.get('category_id')

        image_model = ImageModel.objects(id=image_id).first()
        if image_model is None:
            # 404 (was 500) to match the other endpoints' "not found" status.
            return {'message': 'Image not found'}, 404
        if image_model.path is None:
            # Previously this fell through and the endpoint returned None.
            return {'message': 'Image has no file path'}, 400

        try:
            with open(image_model.path, 'rb') as image:
                response = requests.post(self.PREDICT_URL, files={'image': image})
            response.raise_for_status()  # check the request succeeded
            detections = response.json()['coco']['annotations']
        except requests.exceptions.RequestException as e:
            return {'message': f"服务器没开 {e}"}, 888
        except (ValueError, KeyError, TypeError) as e:
            return {'message': f"Unexpected detection service response: {e}"}, 502

        # Convert every detection before saving anything, so a malformed
        # detection cannot leave a partially saved set behind.
        boxes = []
        for detection in detections:
            try:
                # bbox arrives as corners [x1, y1, x2, y2]; convert it to
                # [x, y, width, height].
                bbox = detection['bbox']
                x = bbox[0]
                y = bbox[1]
                w = bbox[2] - x
                h = bbox[3] - y
                keypoints = detection.get('keypoints', [])
            except (KeyError, IndexError, TypeError, AttributeError) as e:
                return {'message': str(e)}, 400
            boxes.append((x, y, w, h, keypoints))

        new_annotations = []
        for x, y, w, h, keypoints in boxes:
            try:
                new_annotation = AnnotationModel(
                    image_id=image_id,
                    category_id=category_id,
                    metadata={},
                    segmentation=[
                        [
                            x, y,
                            x + w, y,
                            x + w, y + h,
                            x, y + h
                        ]
                    ],
                    bbox=[x, y, w, h],
                    # Box area (was hard-coded to 0); /ocr fills it the same way.
                    area=w * h,
                    keypoints=keypoints,
                    isbbox=True
                )
                new_annotation.save()
                new_annotations.append(query_util.fix_ids(new_annotation))
            except (ValueError, TypeError) as e:
                return {'message': str(e)}, 400

        return {'annotations': new_annotations}, 200


@api.route('/ocr')
class AnnotationOcr(Resource):
    """Create text box annotations from Baidu's general OCR API."""

    OCR_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/general"
    # SECURITY: a live credential should not be committed to source control.
    # It can be supplied via the BAIDU_OCR_ACCESS_TOKEN environment variable.
    # The literal below is only a backward-compatible fallback and should be
    # rotated and removed.
    DEFAULT_ACCESS_TOKEN = '24.089d6ca2f96dbdf05b3b1d8b4363673f.2592000.1714044728.282335-58316360'

    # Document the parser that post() actually reads (was create_batch).
    @api.expect(create_annotation)
    def post(self):
        """Run OCR on an image and save one box annotation per text line.

        Boxes go into a category named 'text' (created on first use), with
        ``creator="ai"`` and the recognised words in ``metadata['name']``.

        Returns:
            ({'annotations': [...]}, 200) on success;
            ({'message': ...}, 404) if the image does not exist;
            ({'message': ...}, 400) if the image has no file path;
            ({'message': ...}, 502) if the OCR call fails or its reply has
            no 'words_result' (the reply is echoed under 'detail').
        """
        import os

        args = create_annotation.parse_args()
        image_id = args.get('image_id')

        image_model = ImageModel.objects(id=image_id).first()
        if image_model is None:
            return {'message': 'Image not found'}, 404
        if image_model.path is None:
            # Previously `result` was left unbound here, raising NameError.
            return {'message': 'Image has no file path'}, 400

        with open(image_model.path, 'rb') as image:
            img = base64.b64encode(image.read())

        access_token = os.environ.get('BAIDU_OCR_ACCESS_TOKEN', self.DEFAULT_ACCESS_TOKEN)
        try:
            response = requests.post(
                self.OCR_URL,
                params={'access_token': access_token},
                data={'image': img},
                headers={'content-type': 'application/x-www-form-urlencoded'},
            )
            result = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            return {'message': f'OCR request failed: {e}'}, 502

        # An unsuccessful reply used to leave `result` unbound or raise
        # KeyError below; surface the service's reply instead.
        if not response.ok or not isinstance(result, dict) or 'words_result' not in result:
            return {'message': 'OCR service error', 'detail': result}, 502
        detections = result['words_result']

        category_model = CategoryModel.objects(name='text').first()
        if category_model is None:
            category_model = CategoryModel(name='text')
            category_model.save()

        annotations = []
        for detection in detections:
            if detection['words'] is None:
                continue
            location = detection['location']
            left = location['left']
            top = location['top']
            width = location['width']
            height = location['height']
            annotation = AnnotationModel(
                image_id=image_id,
                category_id=category_model.id,
                # Four-corner polygon starting at the top-right corner.
                segmentation=[
                    [
                        left + width, top,
                        left + width, top + height,
                        left, top + height,
                        left, top
                    ]
                ],
                bbox=[left, top, width, height],
                creator="ai",
                metadata={"name": detection['words']},
                area=(width * height),
                isbbox=True
            )
            annotation.save()
            annotations.append(query_util.fix_ids(annotation))

        return {'annotations': annotations}, 200


@api.route('/batch')
class AnnotationBatch(Resource):
    @api.expect(create_batch)
    @login_required
    def post(self):
        """Queue batch auto-annotation of a dataset as a Celery task.

        The ``batch`` task is only enqueued; this request does not wait for
        it. The previous version inspected the result right after queueing,
        when it is almost never ready, and then returned nothing.

        Returns:
            ({'task_id': ..., 'status': ...}, 200) so the caller can track
            the task; ``status`` is whatever Celery reports at enqueue time.
        """
        args = create_batch.parse_args()
        dataset_id = args.get('dataset_id')
        # reqparse stores omitted arguments as None, so `.get(key, default)`
        # never falls back. Apply the intended empty defaults explicitly.
        category_ids = args.get('category_ids') or []
        metadata = args.get('metadata') or {}
        box_treshold = args.get('box_treshold')
        text_treshold = args.get('text_treshold')

        task = batch.delay(dataset_id, category_ids, metadata, box_treshold, text_treshold)
        async_result = AsyncResult(id=task.id, app=celery)
        logger.info(f'Queued batch annotation task {task.id} for dataset {dataset_id}')
        return {'task_id': task.id, 'status': async_result.status}, 200


# After whole_image has segmented the picture, create a new annotation from an
# existing one (copying its geometry) under the requested category.
@api.route('/annotation_by_segmentation')
class AnnotationByWholeImage(Resource):
    @api.expect(annotation_by_segmentation)
    @login_required
    def post(self):
        """Copy annotation ``annotation_id`` onto ``image_id`` as ``category_id``.

        The copy takes the source's segmentation, bbox, dataset id, color,
        metadata, creator, area, paper_object and isbbox, and is created with
        ``deleted=False``.

        Returns:
            ({'annotation': ...}, 200) on success;
            ({'message': ...}, 400) if an argument is missing or the copy fails;
            ({'message': ...}, 404) if the source annotation does not exist.
        """
        args = annotation_by_segmentation.parse_args()
        image_id = args.get('image_id')
        annotation_id = args.get('annotation_id')
        category_id = args.get('category_id')
        if image_id is None or annotation_id is None or category_id is None:
            return {'message': 'image_id or annotation_id or category_id is None'}, 400

        source = AnnotationModel.objects(id=annotation_id).first()
        if source is None:
            # Previously surfaced as an AttributeError message with status 400.
            return {'message': 'Annotation not found'}, 404

        try:
            new_annotation = AnnotationModel(
                image_id=image_id,
                category_id=category_id,
                segmentation=source.segmentation,
                bbox=source.bbox,
                # Was misspelled `dateset_id`, so the value was never passed
                # as `dataset_id`.
                dataset_id=source.dataset_id,
                color=source.color,
                metadata=source.metadata,
                creator=source.creator,
                area=source.area,
                deleted=False,
                paper_object=source.paper_object,
                isbbox=source.isbbox
            )
            new_annotation.save()
        except Exception as e:
            # API boundary: report to the client, but keep the traceback in logs.
            logger.exception(f'Failed to copy annotation {annotation_id}')
            return {'message': str(e)}, 400

        return {'annotation': query_util.fix_ids(new_annotation)}, 200


@api.route('/auto/whole_image')
class AnnotationAutoImage(Resource):
    """Segment a whole image with the remote service and store the results."""

    SEGMENT_URL = 'http://125.220.157.228:29991/'
    # Category under which whole-image segmentation results are stored.
    ENHANCE_CATEGORY = "category_used_to_enhance"

    @api.expect(create_image_id)
    def post(self):
        """Segment ``image_id`` once and store each result as an annotation.

        Results use the ``category_used_to_enhance`` category, which is
        created on first use. Soft-deleted results on this image are restored
        first; if the image then has any live annotation in that category,
        the remote call is skipped.

        Returns:
            ({'annotations': [...]}, 200) after segmenting;
            ({'message': ...}, 200) if the image was already processed;
            ({'message': ...}, 404) if the image does not exist;
            ({'message': ...}, 888) if the segmentation service is unreachable;
            ({'message': ...}, 400) if an annotation cannot be created.
        """
        args = create_image_id.parse_args()
        image_id = args.get('image_id')

        image = ImageModel.objects(id=image_id).first()
        if image is None:
            return {'message': 'Image not found'}, 404

        category = CategoryModel.objects(name=self.ENHANCE_CATEGORY).first()
        if category is not None:
            # Restore earlier results that were soft-deleted for this image.
            restored = AnnotationModel.objects(
                image_id=image_id, category_id=category.id, deleted=True)
            for annotation in restored:
                annotation.deleted = False
                annotation.save()
        else:
            category = CategoryModel(name=self.ENHANCE_CATEGORY, deleted=False)
            category.save()
            # NOTE(review): this append happens in memory only. `image` is
            # never saved afterwards, so the link is not persisted. Confirm
            # whether it should be.
            image.categories.append(category.id)
        category_id = category.id

        existing_annotation = AnnotationModel.objects(
            image_id=image_id, category_id=category_id, deleted=False).first()
        if existing_annotation is not None:
            logger.info(f"Annotations for image {image_id} already exist.")
            return {"message": "whole image already been processed."}, 200

        try:
            with open(image.path, 'rb') as image_file:
                r = requests.post(self.SEGMENT_URL, files={'image': image_file})
        except requests.exceptions.RequestException as e:
            # Same unreachable-service response as /auto and /auto/roi.
            return {'message': f"服务器没开 {e}"}, 888
        annotations = r.json()['coco']['annotations']

        new_annotations = []
        for annotation in annotations:
            segmentation_lists = annotation['segmentation']
            if not segmentation_lists:
                # Nothing to keep, and max() would raise on an empty list.
                continue
            # Keep only the longest segmentation list (the first one on ties),
            # wrapped to preserve the list-of-lists shape.
            longest_segmentation = [max(segmentation_lists, key=len)]
            area = annotation['area']
            bbox = annotation['bbox']
            try:
                new_annotation = AnnotationModel(
                    image_id=image_id,
                    category_id=category_id,
                    metadata={},
                    segmentation=longest_segmentation,
                    bbox=bbox,
                    area=area,
                    keypoints=[],
                    isbbox=False
                )
                new_annotation.save()
                new_annotations.append(query_util.fix_ids(new_annotation))
            except (ValueError, TypeError) as e:
                return {'message': str(e)}, 400

        return {"annotations": new_annotations}, 200


@api.route('/auto/roi')
class AnnotationAutoRoi(Resource):
    """Segment a rectangular region with the remote service and store the results."""

    ROI_URL = 'http://125.220.157.228:29991/roi'

    @api.expect(create_roi)
    def post(self):
        """Send the image and region (x, y, width, height) for segmentation.

        Each returned annotation is stored with the request's category,
        metadata, keypoints and isbbox.

        Returns:
            ({'annotations': [...]}, 200) on success;
            ({'message': ...}, 404) if the image does not exist;
            ({'message': ...}, 888) if the segmentation service is unreachable;
            ({'message': ...}, 400) if an annotation cannot be created.
        """
        args = create_roi.parse_args()
        image_id = args.get('image_id')
        category_id = args.get('category_id')
        isbbox = args.get('isbbox')
        # reqparse stores omitted arguments as None, so `.get(key, default)`
        # never falls back. Apply the intended empty defaults explicitly.
        metadata = args.get('metadata') or {}
        keypoints = args.get('keypoints') or []
        region = {
            'x': args.get('x'),
            'y': args.get('y'),
            'width': args.get('width'),
            'height': args.get('height')
        }

        image = ImageModel.objects(id=image_id).first()
        if image is None:
            return {'message': 'Image not found'}, 404

        logger.info(region)
        try:
            with open(image.path, 'rb') as image_file:
                r = requests.post(self.ROI_URL, data=region, files={'image': image_file})
        except requests.exceptions.RequestException as e:
            return {'message': f"服务器没开 {e}"}, 888
        # Parse the reply once (it was previously decoded twice).
        response = r.json()
        logger.info(response)
        annotations = response['annotations']

        new_annotations = []
        for annotation in annotations:
            segmentation_lists = annotation['segmentation']
            if not segmentation_lists:
                # Nothing to keep, and max() would raise on an empty list.
                continue
            # Keep only the longest segmentation list (the first one on ties),
            # wrapped to preserve the list-of-lists shape.
            longest_segmentation = [max(segmentation_lists, key=len)]
            area = annotation['area']
            bbox = annotation['bbox']
            try:
                new_annotation = AnnotationModel(
                    image_id=image_id,
                    category_id=category_id,
                    metadata=metadata,
                    segmentation=longest_segmentation,
                    bbox=bbox,
                    area=area,
                    keypoints=keypoints,
                    isbbox=isbbox
                )
                new_annotation.save()
                new_annotations.append(query_util.fix_ids(new_annotation))
            except (ValueError, TypeError) as e:
                return {'message': str(e)}, 400

        return {"annotations": new_annotations}, 200


@api.route('/auto')
class AnnotationAuto(Resource):
    """Segment from a single (x, y) prompt with the remote service and store the results."""

    SINGLE_URL = 'http://125.220.157.228:29991/single'

    @api.expect(create_auto)
    def post(self):
        """Send the image and the point (x, y) for segmentation.

        Each returned annotation is stored with the request's category,
        metadata, keypoints and isbbox.

        Returns:
            ({'annotations': [...]}, 200) on success;
            ({'message': ...}, 404) if the image does not exist;
            ({'message': ...}, 888) if the segmentation service is unreachable;
            ({'message': ...}, 400) if an annotation cannot be created.
        """
        args = create_auto.parse_args()
        image_id = args.get('image_id')
        category_id = args.get('category_id')
        isbbox = args.get('isbbox')
        # reqparse stores omitted arguments as None, so `.get(key, default)`
        # never falls back. Apply the intended empty defaults explicitly.
        metadata = args.get('metadata') or {}
        keypoints = args.get('keypoints') or []
        point = {
            'x': args.get('x'),
            'y': args.get('y')
        }

        image = ImageModel.objects(id=image_id).first()
        if image is None:
            return {'message': 'Image not found'}, 404

        try:
            with open(image.path, 'rb') as image_file:
                r = requests.post(self.SINGLE_URL, data=point, files={'image': image_file})
        except requests.exceptions.RequestException as e:
            return {'message': f"服务器没开 {e}"}, 888
        # Parse the reply once (it was previously decoded twice).
        response = r.json()
        logger.info(response)
        annotations = response['annotations']

        new_annotations = []
        for annotation in annotations:
            segmentation_lists = annotation['segmentation']
            if not segmentation_lists:
                # Nothing to keep, and max() would raise on an empty list.
                continue
            # Keep only the longest segmentation list (the first one on ties),
            # wrapped to preserve the list-of-lists shape.
            longest_segmentation = [max(segmentation_lists, key=len)]
            area = annotation['area']
            bbox = annotation['bbox']
            try:
                new_annotation = AnnotationModel(
                    image_id=image_id,
                    category_id=category_id,
                    metadata=metadata,
                    segmentation=longest_segmentation,
                    bbox=bbox,
                    area=area,
                    keypoints=keypoints,
                    isbbox=isbbox
                )
                new_annotation.save()
                new_annotations.append(query_util.fix_ids(new_annotation))
            except (ValueError, TypeError) as e:
                return {'message': str(e)}, 400

        return {"annotations": new_annotations}, 200


@api.route('/')
class Annotation(Resource):

    @login_required
    def get(self):
        """Return every annotation in ``current_user.annotations``, without ``paper_object``."""
        return query_util.fix_ids(current_user.annotations.exclude("paper_object").all())

    @api.expect(create_annotation)
    @login_required
    def post(self):
        """Create an annotation on one of the current user's non-deleted images.

        Returns:
            The new annotation, passed through ``query_util.fix_ids``;
            ({'message': ...}, 400) if the image id is not found for this user
            or the model rejects the values.
        """
        args = create_annotation.parse_args()
        image_id = args.get('image_id')
        category_id = args.get('category_id')
        isbbox = args.get('isbbox')
        # reqparse stores omitted arguments as None, so `.get(key, default)`
        # never falls back. Apply the intended empty defaults explicitly.
        metadata = args.get('metadata') or {}
        segmentation = args.get('segmentation') or []
        keypoints = args.get('keypoints') or []

        image = current_user.images.filter(id=image_id, deleted=False).first()
        if image is None:
            return {"message": "Invalid image id"}, 400

        try:
            annotation = AnnotationModel(
                image_id=image_id,
                category_id=category_id,
                metadata=metadata,
                segmentation=segmentation,
                keypoints=keypoints,
                isbbox=isbbox
            )
            annotation.save()
        except (ValueError, TypeError) as e:
            return {'message': str(e)}, 400

        # Logged once, and only after the save succeeded (previously logged
        # twice, before the save could fail).
        logger.info(
            f'{current_user.username} has created an annotation for image {image_id} with {isbbox}')

        return query_util.fix_ids(annotation)


@api.route('/<int:annotation_id>')
class AnnotationId(Resource):

    @login_required
    def get(self, annotation_id):
        """Return one of the current user's annotations by id (400 if not found)."""
        annotation = current_user.annotations.filter(id=annotation_id).first()

        if annotation is None:
            return {"message": "Invalid annotation id"}, 400

        return query_util.fix_ids(annotation)

    @login_required
    def delete(self, annotation_id):
        """Soft-delete an annotation by id.

        Sets ``deleted`` and ``deleted_date`` instead of removing the document.
        Calls ``flag_thumbnail()`` on the owning image when that image is
        still visible to the user.
        """
        annotation = current_user.annotations.filter(id=annotation_id).first()

        if annotation is None:
            return {"message": "Invalid annotation id"}, 400

        image = current_user.images.filter(
            id=annotation.image_id, deleted=False).first()
        # The image may itself be deleted. Previously that raised
        # AttributeError here and the annotation was never marked deleted.
        if image is not None:
            image.flag_thumbnail()

        annotation.update(set__deleted=True,
                          set__deleted_date=datetime.datetime.now())
        return {'success': True}

    @api.expect(update_annotation)
    @login_required
    def put(self, annotation_id):
        """Change an annotation's category and return the updated annotation."""
        annotation = current_user.annotations.filter(id=annotation_id).first()

        if annotation is None:
            return {"message": "Invalid annotation id"}, 400

        args = update_annotation.parse_args()

        new_category_id = args.get('category_id')
        if new_category_id is None:
            # Omitting category_id used to pass None straight to update().
            return {"message": "category_id is required"}, 400
        annotation.update(category_id=new_category_id)
        logger.info(
            f'{current_user.username} has updated category for annotation (id: {annotation.id})'
        )
        updated_annotation = current_user.annotations.filter(id=annotation_id).first()
        return query_util.fix_ids(updated_annotation)


# @api.route('/<int:annotation_id>/mask')
# class AnnotationMask(Resource):
#     def get(self, annotation_id):
#         """ Returns the binary mask of an annotation """
#         return query_util.fix_ids(AnnotationModel.objects(id=annotation_id).first())


@api.route('/find_annotation_which_has_this_point')
class AnnotationContainsThisPoint(Resource):
    @api.expect(annotation_contains_this_point)
    def post(self):
        """Find an annotation on ``image_id`` that contains the point (x, y).

        Only non-deleted annotations are considered. The geometric test is
        delegated to ``annotations_containing_point``. Returns its result
        passed through ``query_util.fix_ids``, or None when it returns None.
        """
        args = annotation_contains_this_point.parse_args()
        image_id = args.get('image_id')
        point = (args.get('x'), args.get('y'))

        annotations = AnnotationModel.objects(image_id=image_id, deleted=False)
        match = annotations_containing_point(point, annotations)

        # A single identity check replaces the duplicated `== None` tests.
        if match is None:
            return None

        return query_util.fix_ids(match)
